Keep track of DeathMonitor cookies This change keeps track of the objects that the cookies points to so the serviceDied callback knows when it can use the cookie. Test: atest neuralnetworks_utils_hal_aidl_test Tets: atest NeuralNetworksTest_static Bug: 319210610 (cherry picked from https://googleplex-android-review.googlesource.com/q/commit:def7a3cf59fa17ba7faa9af14a24f4161bc276bd) (cherry picked from https://googleplex-android-review.googlesource.com/q/commit:49859a3b5542270363efe42a56b9145142bbfa60) Merged-In: I418cbc6baa19aa702d9fd2e7d8096fe1a02b7794 Change-Id: I418cbc6baa19aa702d9fd2e7d8096fe1a02b7794
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/ProtectCallback.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/ProtectCallback.h index 92ed1cd..9a7fe5e 100644 --- a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/ProtectCallback.h +++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/ProtectCallback.h
@@ -56,6 +56,8 @@ // Thread safe class class DeathMonitor final { public: + explicit DeathMonitor(uintptr_t cookieKey) : kCookieKey(cookieKey) {} + static void serviceDied(void* cookie); void serviceDied(); // Precondition: `killable` must be non-null. @@ -63,9 +65,18 @@ // Precondition: `killable` must be non-null. void remove(IProtectedCallback* killable) const; + uintptr_t getCookieKey() const { return kCookieKey; } + + ~DeathMonitor(); + DeathMonitor(const DeathMonitor&) = delete; + DeathMonitor(DeathMonitor&&) noexcept = delete; + DeathMonitor& operator=(const DeathMonitor&) = delete; + DeathMonitor& operator=(DeathMonitor&&) noexcept = delete; + private: mutable std::mutex mMutex; mutable std::vector<IProtectedCallback*> mObjects GUARDED_BY(mMutex); + const uintptr_t kCookieKey; }; class DeathHandler final { diff --git a/neuralnetworks/aidl/utils/src/ProtectCallback.cpp b/neuralnetworks/aidl/utils/src/ProtectCallback.cpp index 54a673c..4a7ac08 100644 --- a/neuralnetworks/aidl/utils/src/ProtectCallback.cpp +++ b/neuralnetworks/aidl/utils/src/ProtectCallback.cpp
@@ -25,6 +25,7 @@ #include <algorithm> #include <functional> +#include <map> #include <memory> #include <mutex> #include <vector> @@ -33,6 +34,16 @@ namespace aidl::android::hardware::neuralnetworks::utils { +namespace { + +// Only dereference the cookie if it's valid (if it's in this set) +// Only used with ndk +std::mutex sCookiesMutex; +uintptr_t sCookieKeyCounter GUARDED_BY(sCookiesMutex) = 0; +std::map<uintptr_t, std::weak_ptr<DeathMonitor>> sCookies GUARDED_BY(sCookiesMutex); + +} // namespace + void DeathMonitor::serviceDied() { std::lock_guard guard(mMutex); std::for_each(mObjects.begin(), mObjects.end(), @@ -40,8 +51,24 @@ } void DeathMonitor::serviceDied(void* cookie) { - auto deathMonitor = static_cast<DeathMonitor*>(cookie); - deathMonitor->serviceDied(); + std::shared_ptr<DeathMonitor> monitor; + { + std::lock_guard<std::mutex> guard(sCookiesMutex); + if (auto it = sCookies.find(reinterpret_cast<uintptr_t>(cookie)); it != sCookies.end()) { + monitor = it->second.lock(); + sCookies.erase(it); + } else { + LOG(INFO) + << "Service died, but cookie is no longer valid so there is nothing to notify."; + return; + } + } + if (monitor) { + LOG(INFO) << "Notifying DeathMonitor from serviceDied."; + monitor->serviceDied(); + } else { + LOG(INFO) << "Tried to notify DeathMonitor from serviceDied but could not promote."; + } } void DeathMonitor::add(IProtectedCallback* killable) const { @@ -57,12 +84,25 @@ mObjects.erase(removedIter); } +DeathMonitor::~DeathMonitor() { + // lock must be taken so object is not used in OnBinderDied" + std::lock_guard<std::mutex> guard(sCookiesMutex); + sCookies.erase(kCookieKey); +} + nn::GeneralResult<DeathHandler> DeathHandler::create(std::shared_ptr<ndk::ICInterface> object) { if (object == nullptr) { return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "utils::DeathHandler::create must have non-null object"; } - auto deathMonitor = std::make_shared<DeathMonitor>(); + + std::shared_ptr<DeathMonitor> deathMonitor; + { + std::lock_guard<std::mutex> guard(sCookiesMutex); + deathMonitor = std::make_shared<DeathMonitor>(sCookieKeyCounter++); + sCookies[deathMonitor->getCookieKey()] = deathMonitor; + } + auto deathRecipient = ndk::ScopedAIBinder_DeathRecipient( AIBinder_DeathRecipient_new(DeathMonitor::serviceDied)); @@ -70,8 +110,9 @@ // STATUS_INVALID_OPERATION. We ignore this case because we only use local binders in tests // where this is not an error. if (object->isRemote()) { - const auto ret = ndk::ScopedAStatus::fromStatus(AIBinder_linkToDeath( - object->asBinder().get(), deathRecipient.get(), deathMonitor.get())); + const auto ret = ndk::ScopedAStatus::fromStatus( + AIBinder_linkToDeath(object->asBinder().get(), deathRecipient.get(), + reinterpret_cast<void*>(deathMonitor->getCookieKey()))); HANDLE_ASTATUS(ret) << "AIBinder_linkToDeath failed"; } @@ -91,8 +132,9 @@ DeathHandler::~DeathHandler() { if (kObject != nullptr && kDeathRecipient.get() != nullptr && kDeathMonitor != nullptr) { - const auto ret = ndk::ScopedAStatus::fromStatus(AIBinder_unlinkToDeath( - kObject->asBinder().get(), kDeathRecipient.get(), kDeathMonitor.get())); + const auto ret = ndk::ScopedAStatus::fromStatus( + AIBinder_unlinkToDeath(kObject->asBinder().get(), kDeathRecipient.get(), + reinterpret_cast<void*>(kDeathMonitor->getCookieKey()))); const auto maybeSuccess = handleTransportError(ret); if (!maybeSuccess.ok()) { LOG(ERROR) << maybeSuccess.error().message; diff --git a/neuralnetworks/aidl/utils/test/DeviceTest.cpp b/neuralnetworks/aidl/utils/test/DeviceTest.cpp index 73727b3..ffd3b8e 100644 --- a/neuralnetworks/aidl/utils/test/DeviceTest.cpp +++ b/neuralnetworks/aidl/utils/test/DeviceTest.cpp
@@ -697,7 +697,8 @@ const auto mockDevice = createMockDevice(); const auto device = Device::create(kName, mockDevice, kVersion).value(); const auto ret = [&device]() { - DeathMonitor::serviceDied(device->getDeathMonitor()); + DeathMonitor::serviceDied( + reinterpret_cast<void*>(device->getDeathMonitor()->getCookieKey())); return ndk::ScopedAStatus::ok(); }; EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _)) @@ -846,7 +847,8 @@ const auto mockDevice = createMockDevice(); const auto device = Device::create(kName, mockDevice, kVersion).value(); const auto ret = [&device]() { - DeathMonitor::serviceDied(device->getDeathMonitor()); + DeathMonitor::serviceDied( + reinterpret_cast<void*>(device->getDeathMonitor()->getCookieKey())); return ndk::ScopedAStatus::ok(); }; EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _)) @@ -970,7 +972,8 @@ const auto mockDevice = createMockDevice(); const auto device = Device::create(kName, mockDevice, kVersion).value(); const auto ret = [&device]() { - DeathMonitor::serviceDied(device->getDeathMonitor()); + DeathMonitor::serviceDied( + reinterpret_cast<void*>(device->getDeathMonitor()->getCookieKey())); return ndk::ScopedAStatus::ok(); }; EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))